# ======================================================================
# Copyright CERFACS (March 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

import numpy
import scipy.sparse

from qaths.utils.math import kron, complete_by_zeros as complete_matrix_by_zeros


def random_complex_array(shape) -> numpy.ndarray:
    """Draw an array of complex numbers with uniformly distributed parts.

    The real and imaginary parts are each sampled with ``numpy.random.rand``
    (uniform on [0, 1)). The real part is drawn first and the imaginary part
    second, so results are reproducible under ``numpy.random.seed``.

    :param shape: iterable of dimensions, unpacked into ``numpy.random.rand``.
    :return: complex array of the requested shape.
    """
    real_part, imaginary_part = (numpy.random.rand(*shape) for _ in range(2))
    return real_part + 1.0j * imaginary_part


def F(n: int) -> numpy.ndarray:
    """Return the diagonal matrix diag(0, 1, ..., 2**n - 1).

    :param n: exponent of the dimension; the matrix is 2**n by 2**n.
    :return: dense matrix whose i-th diagonal entry is i.
    """
    dimension = 2 ** n
    return numpy.diagflat(numpy.arange(dimension))


# Single-qubit projectors onto |0> and |1>.
_P0 = numpy.diag([1, 0])
_P1 = numpy.diag([0, 1])

# Standard single-qubit gates.
I = numpy.diag([1, 1])
X = numpy.array([[0, 1], [1, 0]])
Y = numpy.array([[0, -1.0j], [1.0j, 0]])
Z = numpy.diag([1, -1])
H = numpy.array([[1, 1], [1, -1]]) / numpy.sqrt(2)

# CNOT as the block-diagonal matrix diag(I, X): X is applied on the
# lower-order bit when the higher-order bit of the basis index is 1.
CX = numpy.block([[I, numpy.zeros_like(I)], [numpy.zeros_like(X), X]])


def ctrl(gate: numpy.ndarray, ctrl: int, trgt: int, total_qubit: int) -> numpy.ndarray:
    """Build the dense matrix of ``gate`` on ``trgt``, controlled by qubit ``ctrl``.

    The result is the sum of two tensor products over all qubits, taken in
    order 0, 1, ..., total_qubit - 1:

    * ``_P0`` on ``ctrl``, identity everywhere else;
    * ``_P1`` on ``ctrl``, ``gate`` on ``trgt``, identity everywhere else.

    NOTE(review): this assumes ``kron(a, b)`` (from ``qaths.utils.math``) is
    the standard Kronecker product a ⊗ b. Under that assumption, qubit 0 is
    the left-most (most significant) tensor factor.

    :param gate: single-qubit (2x2) matrix applied on the target qubit.
    :param ctrl: index of the control qubit, in ``[0, total_qubit)``.
    :param trgt: index of the target qubit, in ``[0, total_qubit)``.
    :param total_qubit: total number of qubits of the system.
    :return: complex matrix of shape (2**total_qubit, 2**total_qubit).
    :raises ValueError: if ``ctrl`` or ``trgt`` is out of range, or if they
        are equal.
    """
    if not (0 <= ctrl < total_qubit and 0 <= trgt < total_qubit):
        raise ValueError(
            "Control ({}) and target ({}) qubits must be in [0, {}).".format(
                ctrl, trgt, total_qubit
            )
        )
    if ctrl == trgt:
        raise ValueError(
            "Control and target qubits must differ, got {} for both.".format(ctrl)
        )
    # numpy.complex (an alias of the builtin complex) was removed in NumPy
    # 1.24; complex128 is the dtype it resolved to.
    ret = numpy.zeros((2 ** total_qubit, 2 ** total_qubit), dtype=numpy.complex128)
    for ctrl_value in (0, 1):
        # Control in |0>: identity on the target. Control in |1>: apply gate.
        projector = _P0 if ctrl_value == 0 else _P1
        target_operator = I if ctrl_value == 0 else gate
        tmp = 1
        for qubit in range(total_qubit):
            if qubit == ctrl:
                factor = projector
            elif qubit == trgt:
                factor = target_operator
            else:
                factor = I
            tmp = kron(tmp, factor)
        ret += tmp
    return ret


def Rzz(theta: float) -> numpy.ndarray:
    """Return the 2x2 matrix exp(1j * theta) times the identity.

    NOTE(review): both diagonal entries carry the same phase, so this is a
    pure global phase. A conventional R_zz gate is a 4x4 two-qubit matrix
    with opposite phases; confirm this is the intended behaviour.

    :param theta: phase angle, in radians.
    :return: complex 2x2 diagonal matrix.
    """
    phase = numpy.exp(1.0j * theta)
    return numpy.diag(numpy.full(2, phase))


def _hermitian_block_embedding(B):
    """Return the sparse block matrix [[0, B], [B^dagger, 0]]."""
    # The zero blocks' shapes are inferred by bmat from B and B^dagger.
    return scipy.sparse.bmat([[None, B], [B.conj().transpose(), None]])


def construct_Dirichlet_Hamiltonians_1D(number_of_discretisation_points: int):
    """Construct the sparse matrices that decompose the Hamiltonian to simulate.

    The Hamiltonian simulated to solve the wave equation is a 2-sparse matrix
    whose only non-zero entries are 1 and -1. Let ``n`` be
    ``number_of_discretisation_points - 2``. It is split into two terms, each
    of the form [[0, B], [B^dagger, 0]], where B has shape (n, n + 1):

    * the first B holds the diagonal: 1 on row 0 and -1 on the other rows;
    * the second B holds the super-diagonal, filled with 1.

    Each term is passed through ``complete_matrix_by_zeros`` and converted to
    CSR. NOTE(review): that helper presumably pads to a power-of-two
    dimension; see ``qaths.utils.math``.

    :param number_of_discretisation_points: The number of points used in the
        discretisation. Must be at least 3.
    :return: list of 2 Hamiltonian matrices, in CSR format, that sum to the
        Hamiltonian described in https://arxiv.org/pdf/1711.05394.pdf.
    :raises ValueError: if fewer than 3 discretisation points are given.
    """
    considered_points = number_of_discretisation_points - 2
    if considered_points < 1:
        raise ValueError(
            "At least 3 discretisation points are required, got {}.".format(
                number_of_discretisation_points
            )
        )

    indices = numpy.arange(considered_points)
    shape = (considered_points, considered_points + 1)
    # numpy.int (an alias of the builtin int) was removed in NumPy 1.24;
    # passing `int` selects the very same dtype.
    diagonal_values = numpy.full((considered_points,), -1, dtype=int)
    diagonal_values[0] = 1
    superdiagonal_values = numpy.ones((considered_points,), dtype=int)

    Bs = [
        scipy.sparse.coo_matrix((diagonal_values, (indices, indices)), shape=shape),
        scipy.sparse.coo_matrix(
            (superdiagonal_values, (indices, indices + 1)), shape=shape
        ),
    ]

    Hs = [
        complete_matrix_by_zeros(_hermitian_block_embedding(B)).tocsr() for B in Bs
    ]

    return Hs
